{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "image_dir = Path('../tests/data/mixed_images')"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Use one of the prepackaged models"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from imagededup.methods import CNN\n",
    "\n",
    "# Get CustomModel construct\n",
    "from imagededup.utils import CustomModel\n",
    "\n",
    "# Get efficientnet model\n",
    "from imagededup.utils.models import EfficientNet\n",
    "# Other models include ViT, MobilenetV3 (default selection)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Declare a custom config with CustomModel, the prepackaged models come with a name and transform function\n",
    "custom_config = CustomModel(name=EfficientNet.name,\n",
    "                            model=EfficientNet(),\n",
    "                            transform=EfficientNet.transform)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "cnn_encoder = CNN(model_config=custom_config)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "duplicates_cnn = cnn_encoder.find_duplicates(image_dir=image_dir, scores=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "duplicates_cnn"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# User-defined model"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from imagededup.methods import CNN\n",
    "\n",
    "# Get CustomModel construct\n",
    "from imagededup.utils import CustomModel\n",
    "\n",
    "# Import necessary pytorch constructs for initializing a custom feature extractor\n",
    "import torch\n",
    "from torchvision.transforms import transforms\n",
    "\n",
    "# Declare custom feature extractor class\n",
    "class MyModel(torch.nn.Module):\n",
    "    transform = transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize((256, 256)),\n",
    "            transforms.ToTensor()\n",
    "        ]\n",
    "    ) # transform must take PIL.Image as input and return a torch.Tensor\n",
    "\n",
    "    name = 'my_custom_model' # name can be any user-defined string\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Define the layers of the model here\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Add more operations here\n",
    "        x = x.view(-1, 256*256*3) # output shape: batch_size x features\n",
    "        return x"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Initialize the CNN using model_config parameter and setting it to the custom model\n",
    "custom_config = CustomModel(name=MyModel.name,\n",
    "                            model=MyModel(),\n",
    "                            transform=MyModel.transform)\n",
    "\n",
    "cnn = CNN(model_config=custom_config)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "duplicates_cnn = cnn.find_duplicates(image_dir=image_dir, scores=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "duplicates_cnn"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Using a huggingface model"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "!pip install transformers"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from imagededup.methods import CNN\n",
    "\n",
    "# Get CustomModel construct\n",
    "from imagededup.utils import CustomModel\n",
    "\n",
    "# Import necessary constructs for initializing a huggingface transformers model\n",
    "from transformers import ViTModel, AutoImageProcessor\n",
    "import torch\n",
    "from torchvision.transforms import transforms\n",
    "\n",
    "VIT_MODEL = \"google/vit-base-patch16-224-in21k\"\n",
    "\n",
    "def vit_transform(image):\n",
    "    transform = AutoImageProcessor.from_pretrained(VIT_MODEL)\n",
    "    x = transform(image, return_tensors = 'pt')['pixel_values']\n",
    "    return x\n",
    "\n",
    "class VitHgface(torch.nn.Module):\n",
    "    transform = transforms.Lambda(vit_transform)\n",
    "\n",
    "    name = 'ViT_hgface'\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.vit = ViTModel.from_pretrained(VIT_MODEL)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x  = x.view(-1, 3, 224, 224)\n",
    "        with torch.no_grad():\n",
    "            out = self.vit(pixel_values=x)\n",
    "        return out.pooler_output"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "custom_config = CustomModel(name=VitHgface.name,\n",
    "                            model=VitHgface(),\n",
    "                            transform=VitHgface.transform)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "cnn = CNN(model_config=custom_config)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "duplicates_cnn = cnn.find_duplicates(image_dir=image_dir, scores=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "duplicates_cnn"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "enc = cnn.encode_images(image_dir=image_dir)"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
